<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGPU Compute Shaders - Histogram, optimized more</title>
    <style>
      @import url(resources/webgpu-lesson.css);
      canvas {
        display: block;
        max-width: 256px;
        border: 1px solid #888;
        background-color: #333;
      }
    </style>
  </head>
  <body>
  </body>
  <script type="module">
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#wgpu-matrix
import { mat4 } from '../3rdparty/wgpu-matrix.module.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#webgpu-utils
import {
  loadImageBitmap,
  createTextureFromSource,
} from '../3rdparty/webgpu-utils-1.x.module.js';

/**
 * Computes histograms of an image (red, green, blue and luminance) on the GPU
 * and draws them into canvases appended to the page.
 *
 * Steps:
 *   1. One compute workgroup per chunk of the image builds a histogram in
 *      workgroup memory and writes it to its slot in `chunksBuffer`.
 *   2. The per-chunk histograms are summed pairwise, doubling the stride each
 *      step, until chunk 0 holds the histogram of the whole image.
 *   3. A single-invocation compute shader derives a per-channel scale from
 *      the largest bin of chunk 0.
 *   4. `drawHistogram` renders the bins of chunk 0 with a render pipeline.
 *
 * Calls `fail` and returns early when no WebGPU device is available or when
 * the image needs more GPU resources than the device limits allow.
 */
async function main() {
  const adapter = await navigator.gpu?.requestAdapter();
  const device = await adapter?.requestDevice();
  if (!device) {
    fail('need a browser that supports WebGPU');
    return;
  }

  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();

  // Constants shared by JavaScript and WGSL. Every entry is emitted as a WGSL
  // `const` so both sides agree on the chunk dimensions.
  const k = {
    chunkWidth: 256,
    chunkHeight: 1,
  };
  const chunkSize = k.chunkWidth * k.chunkHeight;
  const sharedConstants = Object.entries(k)
    .map(([name, value]) => `const ${name} = ${value};`)
    .join('\n');

  // Pass 1: each workgroup covers one chunkWidth x chunkHeight region of the
  // texture, counts its texels into workgroup-local atomic bins and then
  // copies those bins to the workgroup's slot in `chunks`.
  const histogramChunkModule = device.createShaderModule({
    label: 'histogram chunk shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;
      var<workgroup> bins: array<array<atomic<u32>, 4>, chunkSize>;
      @group(0) @binding(0) var<storage, read_write> chunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var ourTexture: texture_2d<f32>;

      const kSRGBLuminanceFactors = vec3f(0.2126, 0.7152, 0.0722);
      fn srgbLuminance(color: vec3f) -> f32 {
        return saturate(dot(color, kSRGBLuminanceFactors));
      }

      @compute @workgroup_size(chunkWidth, chunkHeight, 1)
      fn cs(
        @builtin(workgroup_id) workgroup_id: vec3u,
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        let size = textureDimensions(ourTexture, 0);
        let position = workgroup_id.xy * vec2u(chunkWidth, chunkHeight) + 
                       local_invocation_id.xy;
        if (all(position < size)) {
          let numBins = f32(chunkSize);
          let lastBinIndex = u32(numBins - 1);
          var channels = textureLoad(ourTexture, position, 0);
          channels.w = srgbLuminance(channels.rgb);
          for (var ch = 0; ch < 4; ch++) {
            let v = channels[ch];
            let bin = min(u32(v * numBins), lastBinIndex);
            atomicAdd(&bins[bin][ch], 1u);
          }
        }

        workgroupBarrier();

        let chunksAcross = (size.x + chunkWidth - 1) / chunkWidth;
        let chunk = workgroup_id.y * chunksAcross + workgroup_id.x;
        let bin = local_invocation_id.y * chunkWidth + local_invocation_id.x;

        chunks[chunk][bin] = vec4u(
          atomicLoad(&bins[bin][0]),
          atomicLoad(&bins[bin][1]),
          atomicLoad(&bins[bin][2]),
          atomicLoad(&bins[bin][3]),
        );
      }
    `,
  });

  // Pass 2 (run once per reduction step): workgroup N adds chunk
  // `N * stride * 2 + stride` into chunk `N * stride * 2`, one bin per
  // invocation.
  const chunkSumModule = device.createShaderModule({
    label: 'chunk sum shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;

      struct Uniforms {
        stride: u32,
      };

      @group(0) @binding(0) var<storage, read_write> chunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var<uniform> uni: Uniforms;

      @compute @workgroup_size(chunkSize, 1, 1) fn cs(
        @builtin(local_invocation_id) local_invocation_id: vec3u,
        @builtin(workgroup_id) workgroup_id: vec3u,
      ) {
        let chunk0 = workgroup_id.x * uni.stride * 2;
        let chunk1 = chunk0 + uni.stride;

        let sum = chunks[chunk0][local_invocation_id.x] +
                  chunks[chunk1][local_invocation_id.x];
        chunks[chunk0][local_invocation_id.x] = sum;
      }
    `,
  });

  // Pass 3: a single invocation finds the largest count per channel and
  // stores the scale the draw shader multiplies bin counts by.
  const scaleModule = device.createShaderModule({
    label: 'histogram scale shader',
    code: /* wgsl */ `
      @group(0) @binding(0) var<storage, read> bins: array<vec4u>;
      @group(0) @binding(1) var<storage, read_write> scale: vec4f;
      @group(0) @binding(2) var ourTexture: texture_2d<f32>;

      @compute @workgroup_size(1, 1, 1) fn cs() {
        let size = textureDimensions(ourTexture, 0);
        let numEntries = f32(size.x * size.y);
        var m = vec4u(0);
        let numBins = arrayLength(&bins);
        for (var i = 0u ; i < numBins; i++) {
          m = max(m, bins[i]);
        }
        scale = max(1.0 / vec4f(m), vec4f(0.2 * f32(numBins) / numEntries));
      }
    `,
  });

  // Render shader: draws a unit quad transformed by `uni.matrix`; each
  // fragment looks up the bin for its x coordinate and picks a color from
  // `uni.colors` based on which channels reach the fragment's y coordinate.
  // `scale` is only read here, so it is bound with `read` access.
  const drawHistogramModule = device.createShaderModule({
    label: 'draw histogram shader',
    code: /* wgsl */ `
      struct OurVertexShaderOutput {
        @builtin(position) position: vec4f,
        @location(0) texcoord: vec2f,
      };

      struct Uniforms {
        matrix: mat4x4f,
        colors: array<vec4f, 16>,
        channelMult: vec4u,
      };

      @group(0) @binding(0) var<storage, read> bins: array<vec4u>;
      @group(0) @binding(1) var<uniform> uni: Uniforms;
      @group(0) @binding(2) var<storage, read> scale: vec4f;

      @vertex fn vs(
        @builtin(vertex_index) vertexIndex : u32
      ) -> OurVertexShaderOutput {
        let pos = array(

          vec2f( 0.0,  0.0),  // center
          vec2f( 1.0,  0.0),  // right, center
          vec2f( 0.0,  1.0),  // center, top

          // 2st triangle
          vec2f( 0.0,  1.0),  // center, top
          vec2f( 1.0,  0.0),  // right, center
          vec2f( 1.0,  1.0),  // right, top
        );

        var vsOutput: OurVertexShaderOutput;
        let xy = pos[vertexIndex];
        vsOutput.position = uni.matrix * vec4f(xy, 0.0, 1.0);
        vsOutput.texcoord = xy;
        return vsOutput;
      }

      @fragment fn fs(fsInput: OurVertexShaderOutput) -> @location(0) vec4f {
        let numBins = arrayLength(&bins);
        let lastBinIndex = u32(numBins - 1);
        let bin = clamp(
            u32(fsInput.texcoord.x * f32(numBins)),
            0,
            lastBinIndex);
        let heights = vec4f(bins[bin]) * scale;
        let bits = heights > vec4f(fsInput.texcoord.y);
        let ndx = dot(select(vec4u(0), uni.channelMult, bits), vec4u(1));
        return uni.colors[ndx];
      }
    `,
  });

  const histogramChunkPipeline = device.createComputePipeline({
    label: 'histogram',
    layout: 'auto',
    compute: {
      module: histogramChunkModule,
    },
  });

  const chunkSumPipeline = device.createComputePipeline({
    label: 'chunk sum',
    layout: 'auto',
    compute: {
      module: chunkSumModule,
    },
  });

  const scalePipeline = device.createComputePipeline({
    label: 'scale',
    layout: 'auto',
    compute: {
      module: scaleModule,
    },
  });

  const drawHistogramPipeline = device.createRenderPipeline({
    label: 'draw histogram',
    layout: 'auto',
    vertex: {
      module: drawHistogramModule,
    },
    fragment: {
      module: drawHistogramModule,
      targets: [{ format: presentationFormat }],
    },
  });

  const imgBitmap = await loadImageBitmap('resources/images/pexels-chevanon-photography-1108099.jpg'); /* webgpufundamentals: url */
  const texture = createTextureFromSource(device, imgBitmap);

  const chunksAcross = Math.ceil(texture.width / k.chunkWidth);
  const chunksDown = Math.ceil(texture.height / k.chunkHeight);
  const numChunks = chunksAcross * chunksDown;
  const binsByteSize = chunkSize * 4 * 4;  // 16 bytes per bin (vec4u)
  const chunksByteSize = numChunks * binsByteSize;

  // The buffer size and the dispatch dimensions grow with the image, so make
  // sure they fit the device before creating anything. Otherwise the only
  // symptom is asynchronous validation errors and blank histograms.
  const { limits } = device;
  const fitsDevice =
      chunksByteSize <= limits.maxBufferSize &&
      chunksByteSize <= limits.maxStorageBufferBindingSize &&
      chunksAcross <= limits.maxComputeWorkgroupsPerDimension &&
      chunksDown <= limits.maxComputeWorkgroupsPerDimension;
  if (!fitsDevice) {
    fail('image is too large to compute a histogram on this device');
    return;
  }

  const chunksBuffer = device.createBuffer({
    size: chunksByteSize,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });

  const scaleBuffer = device.createBuffer({
    size: 4 * 4,
    usage: GPUBufferUsage.STORAGE,
  });

  const histogramBindGroup = device.createBindGroup({
    layout: histogramChunkPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer }},
      { binding: 1, resource: texture.createView() },
    ],
  });

  // Only chunk 0 is bound, so `arrayLength(&bins)` in the shader equals the
  // number of bins.
  const scaleBindGroup = device.createBindGroup({
    layout: scalePipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer, size: binsByteSize }},
      { binding: 1, resource: { buffer: scaleBuffer }},
      { binding: 2, resource: texture.createView() },
    ],
  });

  // One bind group per reduction step; each carries its own stride uniform
  // (1, 2, 4, ...).
  const sumBindGroups = [];
  const numSteps = Math.ceil(Math.log2(numChunks));
  for (let i = 0; i < numSteps; ++i) {
    const stride = 2 ** i;
    const uniformBuffer = device.createBuffer({
      size: 4,
      usage: GPUBufferUsage.UNIFORM,
      mappedAtCreation: true,
    });
    new Uint32Array(uniformBuffer.getMappedRange()).set([stride]);
    uniformBuffer.unmap();

    const chunkSumBindGroup = device.createBindGroup({
      layout: chunkSumPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer }},
        { binding: 1, resource: { buffer: uniformBuffer }},
      ],
    });
    sumBindGroups.push(chunkSumBindGroup);
  }

  const encoder = device.createCommandEncoder({ label: 'histogram encoder' });
  const pass = encoder.beginComputePass();

  // create a histogram for each chunk
  pass.setPipeline(histogramChunkPipeline);
  pass.setBindGroup(0, histogramBindGroup);
  pass.dispatchWorkgroups(chunksAcross, chunksDown);

  // reduce the chunks: every step merges pairs of the chunks that are still
  // live; with an odd count the last chunk is left for a later step.
  pass.setPipeline(chunkSumPipeline);
  let chunksLeft = numChunks;
  for (const bindGroup of sumBindGroups) {
    pass.setBindGroup(0, bindGroup);
    const dispatchCount = Math.floor(chunksLeft / 2);
    chunksLeft -= dispatchCount;
    pass.dispatchWorkgroups(dispatchCount);
  }

  // Compute scales for the channels
  pass.setPipeline(scalePipeline);
  pass.setBindGroup(0, scaleBindGroup);
  pass.dispatchWorkgroups(1);
  pass.end();

  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  showImageBitmap(imgBitmap);

  // draw the red, green, and blue channels
  drawHistogram([0, 1, 2]);

  // draw the luminosity channel
  drawHistogram([3]);

  /**
   * Renders the selected channels of the summed histogram (chunk 0 of
   * `chunksBuffer`) into a new canvas that is appended to the document body.
   *
   * @param {number[]} channels Channel indices to draw: 0 = red, 1 = green,
   *     2 = blue, 3 = luminance.
   * @param {number} [height=100] Height of the canvas in pixels; its width
   *     is the number of bins.
   */
  function drawHistogram(channels, height = 100) {
    const numBins = chunkSize;

    // Layout of the WGSL `Uniforms` struct, in 32-bit elements:
    //   matrix: mat4x4f           -> 16 floats
    //   colors: array<vec4f, 16>  -> 64 floats
    //   channelMult: vec4u        ->  4 u32s
    const uniformValuesAsF32 = new Float32Array(16 + 64 + 4);
    const uniformValuesAsU32 = new Uint32Array(uniformValuesAsF32.buffer);
    const uniformBuffer = device.createBuffer({
      label: 'draw histogram uniform buffer',
      size: uniformValuesAsF32.byteLength,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    const subpart = (view, offset, length) => view.subarray(offset, offset + length);
    const matrix = subpart(uniformValuesAsF32, 0, 16);
    const colors = subpart(uniformValuesAsF32, 16, 64);
    const channelMult = subpart(uniformValuesAsU32, 16 + 64, 4);

    // The fragment shader sums `channelMult` over the channels whose bar
    // reaches the fragment, so the color index is a bit mask:
    // bit 0 = red, bit 1 = green, bit 2 = blue, bit 3 = luminance.
    colors.set([
      [0, 0, 0, 1],
      [1, 0, 0, 1],
      [0, 1, 0, 1],
      [1, 1, 0, 1],
      [0, 0, 1, 1],
      [1, 0, 1, 1],
      [0, 1, 1, 1],
      [0.5, 0.5, 0.5, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1],
    ].flat());
    // Channels that were not requested get a multiplier of 0 and therefore
    // never contribute to the color index.
    channelMult.set(
        Array.from({ length: 4 }, (_, i) => (channels.includes(i) ? 2 ** i : 0)));

    // Map the unit quad (0..1) onto all of clip space (-1..1).
    mat4.identity(matrix);
    mat4.translate(matrix, [-1, -1, 0], matrix);
    mat4.scale(matrix, [2, 2, 1], matrix);
    device.queue.writeBuffer(uniformBuffer, 0, uniformValuesAsF32);

    const bindGroup = device.createBindGroup({
      layout: drawHistogramPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer, size: chunkSize * 4 * 4 }},
        { binding: 1, resource: { buffer: uniformBuffer } },
        { binding: 2, resource: { buffer: scaleBuffer }},
      ],
    });

    // Size the canvas first so the context is configured for its final
    // dimensions: one pixel column per bin.
    const canvas = document.createElement('canvas');
    canvas.width = numBins;
    canvas.height = height;
    const context = canvas.getContext('webgpu');
    context.configure({
      device,
      format: presentationFormat,
    });
    document.body.appendChild(canvas);

    // Get the current texture from the canvas context and
    // set it as the texture to render to.
    const renderPassDescriptor = {
      label: 'our basic canvas renderPass',
      colorAttachments: [
        {
          view: context.getCurrentTexture().createView(),
          clearValue: [0.3, 0.3, 0.3, 1],
          loadOp: 'clear',
          storeOp: 'store',
        },
      ],
    };

    const encoder = device.createCommandEncoder({ label: 'render histogram' });
    const pass = encoder.beginRenderPass(renderPassDescriptor);
    pass.setPipeline(drawHistogramPipeline);
    pass.setBindGroup(0, bindGroup);
    pass.draw(6);  // call our vertex shader 6 times
    pass.end();

    const commandBuffer = encoder.finish();
    device.queue.submit([commandBuffer]);
  }
}

/**
 * Shows an ImageBitmap on the page by transferring it into a freshly created
 * canvas and appending that canvas to the document body.
 *
 * @param {ImageBitmap} imageBitmap The bitmap to display.
 */
function showImageBitmap(imageBitmap) {
  const target = document.createElement('canvas');

  // The canvas size has to be set explicitly because of a Firefox bug:
  // https://bugzilla.mozilla.org/show_bug.cgi?id=1850871
  const { width, height } = imageBitmap;
  target.width = width;
  target.height = height;

  target.getContext('bitmaprenderer').transferFromImageBitmap(imageBitmap);
  document.body.appendChild(target);
}

/**
 * Reports a fatal problem to the user through an alert dialog.
 *
 * @param {string} msg The message to show.
 */
function fail(msg) {
  // eslint-disable-next-line no-alert
  window.alert(msg);
}

// Start the demo. `main` is async, so report a rejection (for example the
// image failing to load) instead of leaving an unhandled promise rejection.
main().catch((err) => {
  console.error(err);
  fail(`histogram example failed: ${err?.message ?? err}`);
});
  </script>
</html>
